#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import numpy as np
import gym
import abc
import math
from osim.env import L2M2019Env
from parl.utils import logger


class FirstTarget(gym.Wrapper):
    """Terminate the episode as soon as the velocity target changes.

    Relies on ``info['target_changed']``, which is produced by
    ``RewardShaping``; the wrapped env must therefore be a ``RewardShaping``
    instance. When the target changes, the step is reported as done and
    ``info['timeout']`` is forced to True.
    """

    def __init__(self, env):
        assert isinstance(env, RewardShaping), type(env)
        gym.Wrapper.__init__(self, env)

    def step(self, action, **kwargs):
        obs, r, done, info = self.env.step(action, **kwargs)
        if not info['target_changed']:
            return obs, r, done, info

        # Early stop: only the first target is used for this episode.
        info['timeout'] = True
        logger.warning(
            '[FirstTarget Wrapper] early stop since first target is finished.'
        )
        return obs, r, True, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)


class ActionScale(gym.Wrapper):
    """Map actions from the range [-1, 1] to [0, 1] before stepping the env.

    The affine map is ``0.5 * (action + 1)``. The result is clipped to
    [0, 1], so out-of-range inputs saturate instead of being passed through.
    """

    def __init__(self, env):
        gym.Wrapper.__init__(self, env)

    def step(self, action, **kwargs):
        scaled = 0.5 * (np.copy(action) + 1.0)
        return self.env.step(np.clip(scaled, 0.0, 1.0), **kwargs)

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)


class FrameSkip(gym.Wrapper):
    """Repeat each action for up to ``skip_num`` env steps.

    Rewards are summed. Info keys containing ``'reward'`` are summed across
    the repeated steps; every other info key keeps its value from the last
    executed step. The loop stops early when the episode is done or when
    ``info['target_changed']`` is set by an inner wrapper. The returned info
    also carries ``frame_count``, the number of env steps since the last
    reset.
    """

    def __init__(self, env, skip_num=4):
        # With skip_num < 1 the loop in step() would never run and the
        # return statement would reference unassigned locals.
        if skip_num < 1:
            raise ValueError(
                'skip_num must be >= 1, got {}'.format(skip_num))
        gym.Wrapper.__init__(self, env)
        self.skip_num = skip_num
        self.frame_count = 0

    def step(self, action, **kwargs):
        r = 0.0
        merge_info = {}
        for _ in range(self.skip_num):
            self.frame_count += 1
            obs, reward, done, info = self.env.step(action, **kwargs)
            r += reward

            for key, value in info.items():
                if 'reward' in key:
                    # accumulate per-step reward components
                    merge_info[key] = merge_info.get(key, 0.0) + value
                else:
                    # non-reward entries: keep the latest value
                    merge_info[key] = value

            if info['target_changed']:
                logger.warning(
                    "[FrameSkip Wrapper] early break since target is changed")
                break

            if done:
                break
        merge_info['frame_count'] = self.frame_count
        return obs, r, done, merge_info

    def reset(self, **kwargs):
        self.frame_count = 0
        return self.env.reset(**kwargs)


class RewardShaping(gym.Wrapper):
    """Base wrapper that builds the per-step info dict used for reward shaping.

    On every step the env's ``info`` is replaced by the dict returned from
    :meth:`reward_shaping`. That dict must contain ``'shaping_reward'``. Two
    flags are then added to it:

    * ``target_changed``: True when the norm of the target velocity
      ``(v_tgt_field[0][5][5], v_tgt_field[1][5][5])`` differs by more than
      ``target_change_threshold`` from the previous step's value;
    * ``timeout``: True once ``max_timelimit`` steps have been taken since
      the last reset.

    The wrapper reads the raw observation dict (``obs['v_tgt_field']``), so
    it must be the first wrapper, i.e. applied directly to the raw env.
    """

    # Change in target speed between consecutive steps above which the
    # target is considered to have changed. Override in a subclass if needed.
    target_change_threshold = 0.2

    def __init__(self, env, max_timelimit):
        logger.info("[RewardShaping]type:{}, max_timelimit: {}".format(
            type(env), max_timelimit))

        self.max_timelimit = max_timelimit

        self.step_count = 0
        # raw observation dict from the most recent step/reset
        self.pre_state_desc = None
        # target speed seen on the previous step (None right after reset)
        self.last_target_vel = None
        gym.Wrapper.__init__(self, env)

    @abc.abstractmethod
    def reward_shaping(self, state_desc, reward, done, action, info):
        """Define your own reward computation.

        Args:
            state_desc(dict): state description for current model
            reward(scalar): generic reward generated by env
            done(bool): generic done flag generated by env
            action: action that produced this step
            info(dict): generic info generated by env

        Returns:
            dict: new info dict; must contain the key ``'shaping_reward'``.
        """
        # The class is not an ABC, so the decorator above is not enforced;
        # fail loudly instead of silently returning None.
        raise NotImplementedError(
            'Subclasses of RewardShaping must implement reward_shaping')

    def step(self, action, **kwargs):
        self.step_count += 1
        obs, r, done, info = self.env.step(action, **kwargs)

        # The subclass-provided dict replaces the env's info entirely.
        info = self.reward_shaping(obs, r, done, action, info)
        assert 'shaping_reward' in info

        v_tgt_field = obs['v_tgt_field']
        target_vel = np.linalg.norm([v_tgt_field[0][5][5], v_tgt_field[1][5][5]])
        target_changed = False
        if self.last_target_vel is not None:
            target_changed = bool(
                np.abs(target_vel - self.last_target_vel) >
                self.target_change_threshold)
        info['target_changed'] = target_changed
        self.last_target_vel = target_vel

        info['timeout'] = self.step_count >= self.max_timelimit
        self.pre_state_desc = obs
        return obs, r, done, info

    def reset(self, **kwargs):
        self.step_count = 0
        self.last_target_vel = None
        obs = self.env.reset(**kwargs)
        self.pre_state_desc = obs
        return obs


class FinalReward(RewardShaping):
    """Reward shaping: a constant alive bonus minus weighted penalties.

    shaping_reward = 10 - penalty_coeff * (vel_penalty_coeff * vel_error
                                           + muscle_penalty_coeff * effort)

    ``vel_error`` is the Euclidean distance between the target velocity at
    ``v_tgt_field[:][5][5]`` and the pelvis velocity rotated by the pelvis
    yaw. ``effort`` is the sum of squared muscle activations.
    """

    def __init__(self, env, max_timelimit, vel_penalty_coeff,
                 muscle_penalty_coeff, penalty_coeff):
        RewardShaping.__init__(self, env, max_timelimit=max_timelimit)

        self.vel_penalty_coeff = vel_penalty_coeff
        self.muscle_penalty_coeff = muscle_penalty_coeff
        self.penalty_coeff = penalty_coeff

    def reward_shaping(self, state_desc, env_reward, done, action, info):
        # Constant bonus for every step survived (not falling down).
        alive_bonus = 10.0

        pelvis_vel = state_desc['body_vel']['pelvis']
        yaw = state_desc['joint_pos']['ground_pelvis'][2]
        v_forward, v_lateral = rotate_frame(pelvis_vel[0], pelvis_vel[2], yaw)
        # Negate so that positive means leftward (Attention!!!), matching
        # the sign convention used for the target field here.
        v_left = -v_lateral

        v_tgt = state_desc['v_tgt_field']
        vel_error = np.linalg.norm(
            [v_tgt[0][5][5] - v_forward, v_tgt[1][5][5] - v_left])

        muscles = state_desc['muscles']
        effort = sum(
            np.square(muscles[name]['activation']) for name in sorted(muscles))

        weighted_penalty = (vel_error * self.vel_penalty_coeff +
                            effort * self.muscle_penalty_coeff)
        shaped = alive_bonus - weighted_penalty * self.penalty_coeff

        return {'shaping_reward': shaped, 'env_reward': env_reward}


class ObsTranformerBase(gym.Wrapper):
    """Base wrapper that converts the raw state description into features.

    Subclasses implement :meth:`_get_observation`. The wrapper also tracks
    ``step_fea``. It starts at ``max_timelimit``, decreases by ``skip_num``
    on every step, and is reset to ``max_timelimit`` whenever an inner
    wrapper sets ``info['target_changed']``. ``skip_num`` presumably matches
    the FrameSkip factor; confirm at the call site.
    """

    def __init__(self, env, max_timelimit, skip_num=4):
        gym.Wrapper.__init__(self, env)
        self.max_timelimit = max_timelimit
        self.skip_num = skip_num

        self.step_fea = self.max_timelimit

    def get_observation(self, state_desc):
        return self._get_observation(state_desc)

    @abc.abstractmethod
    def _get_observation(self, state_desc):
        # Not enforced by the decorator (the class is not an ABC); raise so a
        # missing override does not silently produce a None observation.
        raise NotImplementedError(
            'Subclasses of ObsTranformerBase must implement _get_observation')

    def feature_normalize(self, obs, mean, std, duplicate_id):
        """Standardize the leading features and drop duplicated ones.

        Args:
            obs: 1-D feature array; its first ``len(mean)`` entries are
                standardized as ``(x - mean) / std``. Remaining entries are
                kept unchanged.
            mean, std: 1-D arrays of per-feature statistics.
            duplicate_id: indices to remove from the result. Out-of-range
                indices are ignored.

        Returns:
            A new array. ``obs`` itself is not modified.
        """
        obs = np.asarray(obs)
        scaler_len = mean.shape[0]
        assert obs.shape[0] >= scaler_len
        normalized = np.concatenate(
            [(obs[:scaler_len] - mean) / std, obs[scaler_len:]])
        # set membership instead of a linear scan per feature
        drop = set(duplicate_id)
        return np.array(
            [value for i, value in enumerate(normalized) if i not in drop])

    def step(self, action, **kwargs):
        obs, r, done, info = self.env.step(action, **kwargs)

        if info['target_changed']:
            # reset step_fea when change target
            self.step_fea = self.max_timelimit

        self.step_fea -= self.skip_num

        obs = self.get_observation(obs)
        return obs, r, done, info

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)
        self.step_fea = self.max_timelimit
        return self.get_observation(obs)


class OfficialObs(ObsTranformerBase):
    """Basically same feature processing as official.

    Reference: https://github.com/stanfordnmbl/osim-rl/blob/master/osim/env/osim.py

    The returned 1-D vector consists of:
      1. pelvis and leg features, standardized with the scaler stored in
         ``scaler_file``, with the scaler's ``duplicate_id`` entries removed;
      2. one time feature derived from ``step_fea``;
      3. for rows 5..10 (column 5) of the target velocity field: forward,
         leftward and speed differences to the current pelvis velocity,
         each divided by 5;
      4. the angle ``atan2`` of the target velocity at [5][5], divided by pi.
    """

    MASS = 75.16460000000001  # 11.777 + 2*(9.3014 + 3.7075 + 0.1 + 1.25 + 0.2166) + 34.2366
    G = 9.80665  # from gait1dof22muscle

    LENGTH0 = 1  # leg length

    # (observation key, osim muscle name prefix). This order defines the
    # order of the muscle features in the output vector.
    _MUSCLES = (
        ('HAB', 'abd'),
        ('HAD', 'add'),
        ('HFL', 'iliopsoas'),
        ('GLU', 'glut_max'),
        ('HAM', 'hamstrings'),
        ('RF', 'rect_fem'),
        ('VAS', 'vasti'),
        ('BFSH', 'bifemsh'),
        ('GAS', 'gastroc'),
        ('SOL', 'soleus'),
        ('TA', 'tib_ant'),
    )

    Fmax = {
        'r_leg': {
            'HAB': 4460.290481,
            'HAD': 3931.8,
            'HFL': 2697.344262,
            'GLU': 3337.583607,
            'HAM': 4105.465574,
            'RF': 2191.74098360656,
            'VAS': 9593.95082,
            'BFSH': 557.11475409836,
            'GAS': 4690.57377,
            'SOL': 7924.996721,
            'TA': 2116.818162
        },
        'l_leg': {
            'HAB': 4460.290481,
            'HAD': 3931.8,
            'HFL': 2697.344262,
            'GLU': 3337.583607,
            'HAM': 4105.465574,
            'RF': 2191.74098360656,
            'VAS': 9593.95082,
            'BFSH': 557.11475409836,
            'GAS': 4690.57377,
            'SOL': 7924.996721,
            'TA': 2116.818162
        }
    }
    lopt = {
        'r_leg': {
            'HAB': 0.0845,
            'HAD': 0.087,
            'HFL': 0.117,
            'GLU': 0.157,
            'HAM': 0.069,
            'RF': 0.076,
            'VAS': 0.099,
            'BFSH': 0.11,
            'GAS': 0.051,
            'SOL': 0.044,
            'TA': 0.068
        },
        'l_leg': {
            'HAB': 0.0845,
            'HAD': 0.087,
            'HFL': 0.117,
            'GLU': 0.157,
            'HAM': 0.069,
            'RF': 0.076,
            'VAS': 0.099,
            'BFSH': 0.11,
            'GAS': 0.051,
            'SOL': 0.044,
            'TA': 0.068
        }
    }

    def __init__(self,
                 env,
                 max_timelimit,
                 skip_num=4,
                 scaler_file='./official_obs_scaler.npz'):
        """
        Args:
            scaler_file: path to an .npz archive with arrays ``mean``,
                ``std`` and ``duplicate_id`` (defaults to the original
                cwd-relative path).
        """
        ObsTranformerBase.__init__(self, env, max_timelimit, skip_num)
        # np.load on an .npz returns an open NpzFile; close it once the
        # arrays have been read into memory.
        with np.load(scaler_file) as data:
            self.mean = data['mean']
            self.std = data['std']
            duplicate_id = data['duplicate_id']
        self.duplicate_id = duplicate_id.astype(np.int32).tolist()

    def _get_observation_dict(self, state_desc):
        """Convert the raw state description into a nested feature dict."""
        obs_dict = {'v_tgt_field': state_desc['v_tgt_field']}

        joint_pos = state_desc['joint_pos']
        joint_vel = state_desc['joint_vel']
        pelvis_vel = state_desc['body_vel']['pelvis']

        # pelvis state (in local frame)
        yaw = joint_pos['ground_pelvis'][2]
        dx_local, dy_local = rotate_frame(pelvis_vel[0], pelvis_vel[2], yaw)
        obs_dict['pelvis'] = {
            'height': state_desc['body_pos']['pelvis'][1],
            # (+) pitching forward
            'pitch': -joint_pos['ground_pelvis'][0],
            # (+) rolling around the forward axis (to the right)
            'roll': joint_pos['ground_pelvis'][1],
            'vel': [
                dx_local,  # (+) forward
                -dy_local,  # (+) leftward
                pelvis_vel[1],  # (+) upward
                -joint_vel['ground_pelvis'][0],  # (+) pitch angular velocity
                joint_vel['ground_pelvis'][1],  # (+) roll angular velocity
                joint_vel['ground_pelvis'][2],  # (+) yaw angular velocity
            ],
        }

        body_weight = self.MASS * self.G

        # leg state
        for leg, side in (('r_leg', 'r'), ('l_leg', 'l')):
            foot_forces = state_desc['forces']['foot_{}'.format(side)]
            # ground reaction forces normalized by bodyweight
            grf = [f / body_weight for f in foot_forces[0:3]]
            grfx_local, grfy_local = rotate_frame(-grf[0], -grf[2], yaw)
            if leg == 'l_leg':
                # mirror the lateral axis: (+) leftward for the left leg
                grfy_local = -grfy_local
            leg_dict = {
                'ground_reaction_forces': [
                    grfx_local,  # (+) forward
                    grfy_local,  # (+) lateral (r_leg: rightward, l_leg: leftward)
                    -grf[1],  # (+) upward
                ],
            }

            hip = 'hip_{}'.format(side)
            knee = 'knee_{}'.format(side)
            ankle = 'ankle_{}'.format(side)
            # joint angles
            leg_dict['joint'] = {
                'hip_abd': -joint_pos[hip][1],  # (+) hip abduction
                'hip': -joint_pos[hip][0],  # (+) extension
                'knee': joint_pos[knee][0],  # (+) extension
                'ankle': -joint_pos[ankle][0],  # (+) extension
            }
            # joint angular velocities
            leg_dict['d_joint'] = {
                'hip_abd': -joint_vel[hip][1],  # (+) hip abduction
                'hip': -joint_vel[hip][0],  # (+) extension
                'knee': joint_vel[knee][0],  # (+) extension
                'ankle': -joint_vel[ankle][0],  # (+) extension
            }

            # muscles: fiber force / Fmax, fiber length and velocity / lopt
            for key, osim_name in self._MUSCLES:
                muscle = state_desc['muscles']['{}_{}'.format(osim_name, side)]
                leg_dict[key] = {
                    'f': muscle['fiber_force'] / self.Fmax[leg][key],
                    'l': muscle['fiber_length'] / self.lopt[leg][key],
                    'v': muscle['fiber_velocity'] / self.lopt[leg][key],
                }

            obs_dict[leg] = leg_dict

        return obs_dict

    def _get_observation(self, state_desc):
        """Build the flat feature vector described in the class docstring."""
        obs_dict = self._get_observation_dict(state_desc)
        pelvis = obs_dict['pelvis']

        res = [pelvis['height'], pelvis['pitch'], pelvis['roll']]
        res += pelvis['vel'][0:6]

        for leg in ('r_leg', 'l_leg'):
            leg_dict = obs_dict[leg]
            res += leg_dict['ground_reaction_forces']
            for group in ('joint', 'd_joint'):
                for joint in ('hip_abd', 'hip', 'knee', 'ankle'):
                    res.append(leg_dict[group][joint])
            for key, _ in self._MUSCLES:
                for quantity in ('f', 'l', 'v'):
                    res.append(leg_dict[key][quantity])

        res = self.feature_normalize(
            np.array(res),
            mean=self.mean,
            std=self.std,
            duplicate_id=self.duplicate_id)

        # Linear map of step_fea: max_timelimit -> -1.0, 0 -> +1.0.
        half_time = self.max_timelimit / 2.0
        remaining_time = (self.step_fea - half_time) / half_time * -1.0
        tail = [remaining_time]

        # target driven (relative coordinate system)
        current_v_x = pelvis['vel'][0]  # (+) forward
        current_v_z = pelvis['vel'][1]  # (+) leftward
        current_speed = np.sqrt(current_v_x**2 + current_v_z**2)

        # future vels (0m, 1m, ..., 5m): rows 5..10, column 5 of the field
        v_tgt_field = state_desc['v_tgt_field']
        for index in range(5, 11):
            target_v_x = v_tgt_field[0][index][5]
            target_v_z = v_tgt_field[1][index][5]
            diff_vel = np.sqrt(target_v_x**2 + target_v_z**2) - current_speed
            tail += [(target_v_x - current_v_x) / 5.0,
                     (target_v_z - current_v_z) / 5.0, diff_vel / 5.0]

        # current relative target theta
        target_theta = math.atan2(v_tgt_field[1][5][5], v_tgt_field[0][5][5])
        tail.append(target_theta / np.pi)

        # one append instead of one per feature group; same order and values
        return np.append(res, tail)


def rotate_frame(x, y, theta):
    """Rotate the 2-D vector ``(x, y)`` by ``theta`` radians.

    Applies the rotation matrix [[cos, -sin], [sin, cos]] and returns the
    rotated components as a tuple ``(x_rot, y_rot)``.
    """
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    return cos_t * x - sin_t * y, sin_t * x + cos_t * y
